# runners/tpe_runner.py
from __future__ import annotations

import time
import warnings

import optuna
import ConfigSpace as CS

from objective import Objective
from loggers import ExperimentLogger

from runners.random_runner import _suggest_from_configspace


def _make_tpe_sampler(*,
                      seed: int,
                      n_startup_trials: int,
                      multivariate: bool,
                      group: bool) -> optuna.samplers.TPESampler:
    """Build a TPESampler, degrading gracefully on older Optuna releases.

    The requested ``multivariate``/``group`` options are passed first. If the
    installed Optuna rejects a keyword with ``TypeError``, the sampler is
    rebuilt without ``group`` and then without both. A ``RuntimeWarning`` is
    emitted when an option that was requested as ``True`` had to be dropped.
    A ``TypeError`` that persists even with only ``seed``/``n_startup_trials``
    is re-raised, since it is not a keyword-support problem.
    """
    base = dict(seed=seed, n_startup_trials=n_startup_trials)
    # Optuna refuses ``group=True`` unless ``multivariate=True`` (ValueError),
    # so grouping is only requested together with multivariate modelling.
    requested = dict(multivariate=multivariate, group=group and multivariate)
    # ``group`` arrived in a later Optuna release than ``multivariate``, so it
    # is the first option to drop.
    for dropped in ((), ("group",), ("group", "multivariate")):
        options = {k: v for k, v in requested.items() if k not in dropped}
        try:
            sampler = optuna.samplers.TPESampler(**base, **options)
        except TypeError:
            if len(dropped) == len(requested):
                raise  # fails even without optional keywords: a real error
            continue
        lost = [k for k in dropped if requested[k]]
        if lost:
            warnings.warn(
                f"Installed Optuna TPESampler does not accept "
                f"{', '.join(lost)}; continuing without it.",
                RuntimeWarning,
                stacklevel=3,
            )
        return sampler


def run_tpe_optuna(*,
                   seed: int,
                   bench: str,
                   cs: CS.ConfigurationSpace,
                   obj: Objective,
                   budget_n: int,
                   logger: ExperimentLogger,
                   method_name: str = "TPE-Optuna",
                   n_startup_trials: int | None = None,
                   multivariate: bool = True,
                   group: bool = True) -> None:
    """Run TPE (Optuna ``TPESampler``) for ``budget_n`` trials, logging each one.

    Args:
        seed: Seed for the sampler; also written to every log record.
        bench: Benchmark name written to every log record.
        cs: Search space, turned into Optuna suggestions per trial by
            ``_suggest_from_configspace``.
        obj: Objective whose ``evaluate(cfg)`` returns ``(loss, sim_time)``;
            the loss is minimised.
        budget_n: Number of trials passed to ``study.optimize``.
        logger: Receives one dict per trial that produced a value.
        method_name: Method label written to every log record.
        n_startup_trials: Random trials before TPE modelling starts. Defaults
            to 20% of ``budget_n`` clamped to ``[10, 100]``.
        multivariate: Request multivariate TPE (dropped with a warning if the
            installed Optuna does not support it).
        group: Request grouped multivariate TPE; only takes effect when
            ``multivariate`` is also true.

    Each log record contains:

    - ``n_eval``: number of trials with a value so far.
    - ``sim_time``: the second value returned by ``obj.evaluate``. It is
      described as accumulated surrogate runtime; presumably the
      accumulation happens inside ``Objective`` — confirm.
    - ``elapsed_time``: wall-clock seconds spent in that trial's
      ``obj.evaluate`` call.
    - ``curr_score`` / ``best_score``: reported as ``1 - loss``, which
      assumes the loss is ``1 - score`` — confirm.
    - ``config``: the evaluated configuration.

    Exceptions raised by ``obj.evaluate`` propagate, because
    ``study.optimize`` is called without ``catch=``.
    """
    # Default: 20% of the budget, at least 10 and at most 100 startup trials.
    if n_startup_trials is None:
        n_startup_trials = max(10, min(100, budget_n // 5))

    sampler = _make_tpe_sampler(seed=seed,
                                n_startup_trials=n_startup_trials,
                                multivariate=multivariate,
                                group=group)

    study = optuna.create_study(direction="minimize",
                                sampler=sampler,
                                pruner=optuna.pruners.NopPruner())

    best = float("inf")
    n_eval = 0  # trials that produced an objective value so far

    def objective(trial: optuna.Trial):
        cfg = _suggest_from_configspace(trial, cs)
        t0 = time.perf_counter()
        curr, sim_t = obj.evaluate(cfg)
        elapsed = time.perf_counter() - t0
        # Stash per-trial metadata on the trial so the callback can log it.
        trial.set_user_attr("sim_time", sim_t)
        trial.set_user_attr("elapsed_time", elapsed)
        trial.set_user_attr("config", cfg)
        return curr

    def cb(study: optuna.Study, trial: optuna.trial.FrozenTrial) -> None:
        nonlocal best, n_eval
        # Trials without a value (e.g. failed ones) are neither counted nor logged.
        if trial.value is None:
            return
        # Count incrementally instead of scanning ``study.trials`` (which
        # copies every stored trial) on each call — avoids O(n^2) overall.
        n_eval += 1
        curr = trial.value
        best = min(best, curr)
        logger.log(dict(
            seed=seed, method=method_name, bench=bench,
            n_eval=n_eval,
            sim_time=trial.user_attrs.get("sim_time", 0.0),
            elapsed_time=trial.user_attrs.get("elapsed_time", 0.0),
            best_score=1 - best, curr_score=1 - curr,
            config=trial.user_attrs.get("config", {}),
        ))

    study.optimize(objective, n_trials=budget_n, callbacks=[cb],
                   show_progress_bar=False)

